{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Bi-LSTM-Torch",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "U0pw1iXA1YU1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "'''\n",
        "  code by Tae Hwan Jung(Jeff Jung) @graykode, modify by wmathor\n",
        "'''\n",
        "import torch\n",
        "import numpy as np\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.utils.data as Data\n",
        "\n",
        "# Alias for torch.FloatTensor; no other cell visible in this notebook reads `dtype`.\n",
        "dtype = torch.FloatTensor"
      ],
      "execution_count": 1,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "v2gK0CLz1bEJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "sentence = (\n",
        "    'GitHub Actions makes it easy to automate all your software workflows '\n",
        "    'from continuous integration and delivery to issue triage and more'\n",
        ")\n",
        "\n",
        "word2idx = {w: i for i, w in enumerate(list(set(sentence.split())))}\n",
        "idx2word = {i: w for i, w in enumerate(list(set(sentence.split())))}\n",
        "n_class = len(word2idx) # classification problem\n",
        "max_len = len(sentence.split())\n",
        "n_hidden = 5\n",
        "\n",
        "def make_data(sentence):\n",
        "    '''Build next-word prediction samples from a sentence.\n",
        "\n",
        "    Sample i holds the one-hot encoded prefix words[0..i], padded on the right\n",
        "    to max_len time steps; its target is the vocabulary index of words[i + 1].\n",
        "    Padding steps are all-zero rows, so they cannot be confused with the\n",
        "    vocabulary word whose index happens to be 0.\n",
        "\n",
        "    Reads the notebook-level word2idx, max_len and n_class.\n",
        "\n",
        "    Returns:\n",
        "        inputs: float32 tensor of shape [max_len - 1, max_len, n_class]\n",
        "        targets: int64 tensor of shape [max_len - 1]\n",
        "    '''\n",
        "    words = sentence.split()  # e.g. ['GitHub', 'Actions', 'makes', ...]\n",
        "    n_samples = max(max_len - 1, 0)\n",
        "\n",
        "    # Preallocate in the final dtypes: avoids the slow list-of-ndarrays to tensor\n",
        "    # conversion and leaves every padding position as a zero vector.\n",
        "    inputs = np.zeros((n_samples, max_len, n_class), dtype=np.float32)\n",
        "    targets = np.zeros(n_samples, dtype=np.int64)\n",
        "\n",
        "    for i in range(n_samples):\n",
        "        prefix_ids = [word2idx[w] for w in words[:i + 1]]  # indices of the first i + 1 words\n",
        "        inputs[i, np.arange(len(prefix_ids)), prefix_ids] = 1.0  # one-hot for each real time step\n",
        "        targets[i] = word2idx[words[i + 1]]  # the word that follows the prefix\n",
        "\n",
        "    return torch.from_numpy(inputs), torch.from_numpy(targets)\n",
        "\n",
        "# Shapes: input_batch [max_len - 1, max_len, n_class], target_batch [max_len - 1]\n",
        "input_batch, target_batch = make_data(sentence)\n",
        "\n",
        "# Serve the samples in shuffled mini-batches of 16\n",
        "dataset = Data.TensorDataset(input_batch, target_batch)\n",
        "loader = Data.DataLoader(dataset, batch_size=16, shuffle=True)"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xRIwwk0SeXNg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class BiLSTM(nn.Module):\n",
        "    '''Single-layer bidirectional LSTM followed by a linear classifier.\n",
        "\n",
        "    Layer sizes come from the notebook-level n_class and n_hidden.\n",
        "    '''\n",
        "\n",
        "    def __init__(self):\n",
        "        super(BiLSTM, self).__init__()\n",
        "        self.lstm = nn.LSTM(input_size=n_class, hidden_size=n_hidden, bidirectional=True)\n",
        "        # Both directions are concatenated, hence n_hidden * 2 input features\n",
        "        self.fc = nn.Linear(n_hidden * 2, n_class)\n",
        "\n",
        "    def forward(self, X):\n",
        "        '''Return class scores [batch_size, n_class] for X [batch_size, max_len, n_class].'''\n",
        "        seq_first = X.transpose(0, 1)  # [max_len, batch_size, n_class]\n",
        "\n",
        "        # No explicit (h_0, c_0): nn.LSTM then starts from zero states created on the\n",
        "        # input's device. Random initial states made every forward pass (including\n",
        "        # prediction) non-deterministic and were always allocated on the CPU.\n",
        "        outputs, _ = self.lstm(seq_first)  # [max_len, batch_size, n_hidden * 2]\n",
        "        last_step = outputs[-1]  # [batch_size, n_hidden * 2]\n",
        "        return self.fc(last_step)  # [batch_size, n_class]\n",
        "\n",
        "# Seed the global RNG so weight initialisation and batch shuffling repeat on re-run\n",
        "torch.manual_seed(0)\n",
        "\n",
        "model = BiLSTM()\n",
        "criterion = nn.CrossEntropyLoss()  # expects raw class scores and integer targets\n",
        "optimizer = optim.Adam(model.parameters(), lr=0.001)"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TpLARKzt2X0-",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 319
        },
        "outputId": "96f67c1d-0408-4a1f-aaba-2473724da15b"
      },
      "source": [
        "# Training\n",
        "for epoch in range(10000):\n",
        "    epoch_loss = 0.0  # sum of per-sample losses seen in this epoch\n",
        "    n_seen = 0\n",
        "\n",
        "    for x, y in loader:\n",
        "        pred = model(x)\n",
        "        loss = criterion(pred, y)\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # criterion returns the batch mean; weight it by batch size for an epoch mean\n",
        "        epoch_loss += loss.item() * x.size(0)\n",
        "        n_seen += x.size(0)\n",
        "\n",
        "    # Report once per logged epoch (not once per mini-batch) with the mean loss\n",
        "    if (epoch + 1) % 1000 == 0:\n",
        "        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(epoch_loss / max(n_seen, 1)))"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 1000 cost = 1.761862\n",
            "Epoch: 1000 cost = 2.065904\n",
            "Epoch: 2000 cost = 1.212625\n",
            "Epoch: 2000 cost = 1.323285\n",
            "Epoch: 3000 cost = 1.037356\n",
            "Epoch: 3000 cost = 0.847872\n",
            "Epoch: 4000 cost = 0.834025\n",
            "Epoch: 4000 cost = 1.005801\n",
            "Epoch: 5000 cost = 0.654449\n",
            "Epoch: 5000 cost = 0.911304\n",
            "Epoch: 6000 cost = 0.580501\n",
            "Epoch: 6000 cost = 0.711413\n",
            "Epoch: 7000 cost = 0.463353\n",
            "Epoch: 7000 cost = 0.605977\n",
            "Epoch: 9000 cost = 0.401217\n",
            "Epoch: 9000 cost = 0.186404\n",
            "Epoch: 10000 cost = 0.237348\n",
            "Epoch: 10000 cost = 0.683901\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Q3QPEI4r2ZeW",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "61b5e3bc-fca5-4bed-942f-db45d83c9d31"
      },
      "source": [
        "# Pred\n",
        "# Inference needs no autograd graph; no_grad replaces the discouraged `.data` accessor\n",
        "with torch.no_grad():\n",
        "    predict = model(input_batch).argmax(dim=1, keepdim=True)  # [n_samples, 1]\n",
        "print(sentence)\n",
        "# squeeze(1) drops only the class axis, so a single-sample batch stays iterable\n",
        "print([idx2word[n.item()] for n in predict.squeeze(1)])"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "GitHub Actions makes it easy to automate all your software workflows from continuous integration and delivery to issue triage and more\n",
            "['it', 'it', 'it', 'easy', 'to', 'automate', 'all', 'your', 'software', 'workflows', 'from', 'continuous', 'integration', 'and', 'delivery', 'to', 'issue', 'and', 'and', 'more']\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VJ6zY8qt-9xo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 5,
      "outputs": []
    }
  ]
}